{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker JumpStart - Deploy Chronos endpoints to AWS for production use"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this demo notebook, we will walk through the process of using the **SageMaker Python SDK** to deploy a **Chronos** model to a cloud endpoint on AWS. To simplify deployment, we will leverage **SageMaker JumpStart**.\n",
    "\n",
    "### Why Deploy to an Endpoint?\n",
    "So far, we’ve seen how to run models locally, which is useful for experimentation. However, in a production setting, a forecasting model is typically just one component of a larger system. Running models locally doesn’t scale well and lacks the reliability needed for real-world applications.\n",
    "\n",
    "To address this, we deploy models as **endpoints** on AWS. An endpoint acts as a **hosted service**: we send it requests containing time series data, and it returns forecasts in response. This lets the forecasting model plug into production workflows that need scalable, real-time inference."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, update the SageMaker SDK to access the latest models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U -q sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create a `JumpStartModel` with the necessary configuration based on the model ID. The key parameters are:\n",
    "- `model_id`: Specifies the model to use. Here, we choose the [Chronos-Bolt (Base)](https://huggingface.co/amazon/chronos-bolt-base) model. Currently, the following model IDs are supported:\n",
    "  - `autogluon-forecasting-chronos-bolt-base` - [Chronos-Bolt (Base)](https://huggingface.co/amazon/chronos-bolt-base).\n",
    "  - `autogluon-forecasting-chronos-bolt-small` - [Chronos-Bolt (Small)](https://huggingface.co/amazon/chronos-bolt-small).\n",
    "  - [Original Chronos models](https://huggingface.co/amazon/chronos-t5-small) in sizes `small`, `base` and `large` can be accessed, e.g., as `autogluon-forecasting-chronos-t5-small`. Note that these models require a GPU to run, are much slower and don't support covariates. Therefore, for most practical purposes we recommend using Chronos-Bolt models instead.\n",
    "- `instance_type`: Defines the AWS instance for serving the endpoint. We use `ml.c5.2xlarge` to run the model on CPU. To use a GPU, select an instance like `ml.g5.2xlarge`, or choose other CPU options such as `ml.m5.xlarge` or `ml.m5.4xlarge`. You can check the pricing for different SageMaker instance types for real-time inference [here](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
    "\n",
    "The `JumpStartModel` will automatically set the necessary attributes such as `image_uri` based on the chosen `model_id` and `instance_type`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.jumpstart.model import JumpStartModel\n",
    "\n",
    "model_id = \"autogluon-forecasting-chronos-bolt-base\"\n",
    "\n",
    "model = JumpStartModel(\n",
    "    model_id=model_id,\n",
    "    instance_type=\"ml.c5.2xlarge\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we deploy the model and create an endpoint. Deployment typically takes a few minutes, as SageMaker provisions the instance, loads the model, and sets up the endpoint for inference.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = model.deploy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Note:** Once the endpoint is deployed, it remains active and incurs charges on your AWS account until it is deleted. The cost depends on factors such as the instance type, the region where the endpoint is hosted, and the duration it remains running. To avoid unnecessary charges, make sure to delete the endpoint when it is no longer needed. For detailed pricing information, refer to the [SageMaker AI pricing page](https://aws.amazon.com/sagemaker-ai/pricing/).\n",
    "\n",
    "\n",
    "If the previous step results in an error, you may need to update the model configuration. For example, specifying a `role` when creating the `JumpStartModel` ensures the necessary AWS resources are accessible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = JumpStartModel(role=\"your-sagemaker-execution-role\", model_id=model_id, instance_type=\"ml.c5.2xlarge\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, you can create a predictor for an existing endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sagemaker.predictor import retrieve_default\n",
    "\n",
    "# endpoint_name = \"NAME-OF-EXISTING-ENDPOINT\"\n",
    "# predictor = retrieve_default(endpoint_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Querying the endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now invoke the endpoint to make a forecast. We send a **payload** to the endpoint, which includes historical time series values and configuration parameters, such as the prediction length. The endpoint processes this input and returns a **response** containing the forecasted values based on the provided data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a utility function to print the response in a pretty format\n",
    "from pprint import pformat\n",
    "\n",
    "\n",
    "def nested_round(data, decimals=2):\n",
    "    \"\"\"Round floats, including those nested inside dicts and lists.\"\"\"\n",
    "    if isinstance(data, float):\n",
    "        return round(data, decimals)\n",
    "    elif isinstance(data, list):\n",
    "        return [nested_round(item, decimals) for item in data]\n",
    "    elif isinstance(data, dict):\n",
    "        return {key: nested_round(value, decimals) for key, value in data.items()}\n",
    "    else:\n",
    "        return data\n",
    "\n",
    "\n",
    "def pretty_format(data):\n",
    "    return pformat(nested_round(data), width=150, sort_dicts=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "07605824",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [-1.58, 0.52, 1.88, 1.39, -1.03, -3.34, -2.67, -0.64, 0.96, 1.59],\n",
      "                  '0.1': [-4.17, -2.71, -1.7, -2.35, -4.79, -6.98, -6.59, -4.87, -3.45, -2.89],\n",
      "                  '0.5': [-1.58, 0.52, 1.88, 1.39, -1.03, -3.34, -2.67, -0.64, 0.96, 1.59],\n",
      "                  '0.9': [1.47, 4.47, 6.27, 5.98, 3.5, 1.11, 2.06, 4.47, 6.41, 7.17]}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\"target\": [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]},\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 10\n",
    "    }\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A payload may also contain **multiple time series**, potentially including `start` and `item_id` fields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d476c397",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [1.41, 1.5, 1.49, 1.45, 1.51],\n",
      "                  '0.1': [0.12, -0.08, -0.25, -0.41, -0.45],\n",
      "                  '0.5': [1.41, 1.5, 1.49, 1.45, 1.51],\n",
      "                  '0.9': [3.29, 3.82, 4.09, 4.3, 4.56],\n",
      "                  'item_id': 'product_A',\n",
      "                  'start': '2024-01-01T10:00:00'},\n",
      "                 {'mean': [-1.22, -1.3, -1.3, -1.14, -1.13],\n",
      "                  '0.1': [-4.51, -5.48, -6.12, -6.5, -7.1],\n",
      "                  '0.5': [-1.22, -1.3, -1.3, -1.14, -1.13],\n",
      "                  '0.9': [2.84, 4.02, 4.92, 5.99, 6.79],\n",
      "                  'item_id': 'product_B',\n",
      "                  'start': '2024-02-02T10:00:00'}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
    "            \"item_id\": \"product_A\",\n",
    "            \"start\": \"2024-01-01T01:00:00\",\n",
    "        },\n",
    "        {\n",
    "            \"target\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
    "            \"item_id\": \"product_B\",\n",
    "            \"start\": \"2024-02-02T03:00:00\",\n",
    "        },\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 5,\n",
    "        \"freq\": \"1h\",\n",
    "        \"quantile_levels\": [0.1, 0.5, 0.9],\n",
    "        \"batch_size\": 2,\n",
    "    },\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chronos-Bolt models also support forecasting with covariates (a.k.a. exogenous features or related time series). These can be provided using the `past_covariates` and `future_covariates` keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': [{'mean': [1.41, 1.5, 1.49], '0.1': [0.12, -0.08, -0.25], '0.5': [1.41, 1.5, 1.49], '0.9': [3.29, 3.82, 4.09]},\n",
      "                 {'mean': [-1.22, -1.3, -1.3], '0.1': [-4.51, -5.48, -6.12], '0.5': [-1.22, -1.3, -1.3], '0.9': [2.84, 4.02, 4.92]}]}\n"
     ]
    }
   ],
   "source": [
    "payload = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"target\": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],\n",
    "            # past_covariates must have the same length as \"target\"\n",
    "            \"past_covariates\": {\n",
    "                \"feat_1\": [3.0, 6.0, 9.0, 6.0, 1.5, 6.0, 9.0, 6.0, 3.0],\n",
    "                \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\", \"A\", \"B\"],\n",
    "            },\n",
    "            # future_covariates must have length equal to \"prediction_length\"\n",
    "            \"future_covariates\": {\n",
    "                \"feat_1\": [2.5, 2.2, 3.3],\n",
    "                \"feat_2\": [\"B\", \"A\", \"A\"],\n",
    "            },\n",
    "        },\n",
    "        {\n",
    "            \"target\": [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0],\n",
    "            \"past_covariates\": {\n",
    "                \"feat_1\": [0.6, 1.2, 1.8, 1.2, 0.3, 1.2, 1.8],\n",
    "                \"feat_2\": [\"A\", \"B\", \"B\", \"B\", \"A\", \"A\", \"A\"],\n",
    "            },\n",
    "            \"future_covariates\": {\n",
    "                \"feat_1\": [1.2, 0.3, 4.4],\n",
    "                \"feat_2\": [\"A\", \"B\", \"A\"],\n",
    "            },\n",
    "        },\n",
    "    ],\n",
    "    \"parameters\": {\n",
    "        \"prediction_length\": 3,\n",
    "        \"quantile_levels\": [0.1, 0.5, 0.9],\n",
    "    },\n",
    "}\n",
    "response = predictor.predict(payload)\n",
    "print(pretty_format(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Endpoint API\n",
    "So far, we have explored several examples of querying the endpoint with different payload structures. Below is a comprehensive API specification detailing all supported parameters, their meanings, and how they affect the model’s predictions.\n",
    "\n",
    "* **inputs** (required): List with at most 1000 time series that need to be forecasted. Each time series is represented by a dictionary with the following keys:\n",
    "    * **target** (required): List of observed numeric time series values.\n",
    "        - It is recommended that each time series contains at least 30 observations.\n",
    "        - If any time series contains fewer than 5 observations, an error will be raised.\n",
    "    * **item_id**: String that uniquely identifies each time series.\n",
    "        - If provided, the ID must be unique for each time series.\n",
    "        - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n",
    "    * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`).\n",
    "        - If the **start** field is provided, then **freq** must also be provided as part of **parameters**.\n",
    "        - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n",
    "    * **past_covariates**: Dictionary containing the past values of the covariates for this time series.\n",
    "        - If the **past_covariates** field is provided, then **future_covariates** must be provided as well, with the same keys.\n",
    "        - Each key in **past_covariates** corresponds to the name of a covariate. Each value must be an array consisting of all-numeric or all-string values, with length equal to the length of **target**.\n",
    "    * **future_covariates**: Dictionary containing the future values of the covariates for this time series (values during the forecast horizon).\n",
    "        - If the **future_covariates** field is provided, then **past_covariates** must be provided as well, with the same keys.\n",
    "        - Each key in **future_covariates** corresponds to the name of a covariate. Each value must be an array consisting of all-numeric or all-string values, with length equal to **prediction_length**.\n",
    "        - If both **past_covariates** and **future_covariates** are provided, a regression model specified by **covariate_model** will be used to incorporate the covariate information into the forecast.\n",
    "* **parameters**: Optional parameters to configure the model.\n",
    "    * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. Defaults to `1`.\n",
    "        - We recommend keeping **prediction_length** <= 64, since larger values will result in inaccurate quantile forecasts. Values above 1000 will raise an error.\n",
    "    * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`.\n",
    "        - Note that Chronos-Bolt cannot produce quantiles outside the [0.1, 0.9] range (predictions outside the range will be clipped).\n",
    "    * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data.\n",
    "        - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n",
    "    * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out-of-memory errors. Defaults to `256`.\n",
    "    * **covariate_model**: Name of the tabular regression model applied to the covariates. Possible options: `GBM` (LightGBM), `LR` (linear regression), `RF` (random forest), `CAT` (CatBoost), `XGB` (XGBoost). Defaults to `GBM`.\n",
    "\n",
    "All keys not marked with (required) are optional.\n",
    "\n",
    "The endpoint response contains the probabilistic (quantile) forecast for each time series included in the request."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with long-format data frames"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The endpoint communicates using JSON format for both input and output. However, in practice, time series data is often stored in a **long-format data frame** (where each row represents a timestamp for a specific item).\n",
    "\n",
    "In the following example, we demonstrate how to:\n",
    "\n",
    "1. Convert a long-format data frame into the JSON payload format required by the endpoint.\n",
    "2. Send the request and retrieve predictions.\n",
    "3. Convert the response back into a long-format data frame for further analysis.\n",
    "\n",
    "First, we load an example dataset stored as a long-format data frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "      <th>unit_sales</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-01</td>\n",
       "      <td>0.879130</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>636.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-08</td>\n",
       "      <td>0.994517</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>123.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-15</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-22</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>339.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-29</td>\n",
       "      <td>0.883309</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>661.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    item_id  timestamp  scaled_price  promotion_email  promotion_homepage  \\\n",
       "0  1062_101 2018-01-01      0.879130              0.0                 0.0   \n",
       "1  1062_101 2018-01-08      0.994517              0.0                 0.0   \n",
       "2  1062_101 2018-01-15      1.005513              0.0                 0.0   \n",
       "3  1062_101 2018-01-22      1.000000              0.0                 0.0   \n",
       "4  1062_101 2018-01-29      0.883309              0.0                 0.0   \n",
       "\n",
       "   unit_sales  \n",
       "0       636.0  \n",
       "1       123.0  \n",
       "2       391.0  \n",
       "3       339.0  \n",
       "4       661.0  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv(\n",
    "    \"https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv\",\n",
    "    parse_dates=[\"timestamp\"],\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We split the data into two parts:\n",
    "- Past data, including historic values of the target column and the covariates.\n",
    "- Future data that contains the future values of the covariates during the forecast horizon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction_length = 8\n",
    "target_col = \"unit_sales\"\n",
    "freq = pd.infer_freq(df[df.item_id == df.item_id[0]][\"timestamp\"])\n",
    "\n",
    "past_df = df.groupby(\"item_id\").head(-prediction_length)\n",
    "future_df = df.groupby(\"item_id\").tail(prediction_length).drop(columns=[target_col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "      <th>unit_sales</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-01</td>\n",
       "      <td>0.879130</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>636.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-08</td>\n",
       "      <td>0.994517</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>123.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-15</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-22</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>339.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-01-29</td>\n",
       "      <td>0.883309</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>661.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    item_id  timestamp  scaled_price  promotion_email  promotion_homepage  \\\n",
       "0  1062_101 2018-01-01      0.879130              0.0                 0.0   \n",
       "1  1062_101 2018-01-08      0.994517              0.0                 0.0   \n",
       "2  1062_101 2018-01-15      1.005513              0.0                 0.0   \n",
       "3  1062_101 2018-01-22      1.000000              0.0                 0.0   \n",
       "4  1062_101 2018-01-29      0.883309              0.0                 0.0   \n",
       "\n",
       "   unit_sales  \n",
       "0       636.0  \n",
       "1       123.0  \n",
       "2       391.0  \n",
       "3       339.0  \n",
       "4       661.0  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "past_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>scaled_price</th>\n",
       "      <th>promotion_email</th>\n",
       "      <th>promotion_homepage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-11</td>\n",
       "      <td>1.005425</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-18</td>\n",
       "      <td>1.005454</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-25</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-02</td>\n",
       "      <td>1.005513</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-09</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     item_id  timestamp  scaled_price  promotion_email  promotion_homepage\n",
       "23  1062_101 2018-06-11      1.005425              0.0                 0.0\n",
       "24  1062_101 2018-06-18      1.005454              0.0                 0.0\n",
       "25  1062_101 2018-06-25      1.000000              0.0                 0.0\n",
       "26  1062_101 2018-07-02      1.005513              0.0                 0.0\n",
       "27  1062_101 2018-07-09      1.000000              0.0                 0.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "future_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now convert this data into a JSON payload."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_df_to_payload(\n",
    "    past_df,\n",
    "    future_df=None,\n",
    "    prediction_length=1,\n",
    "    freq=\"D\",\n",
    "    target_col=\"target\",\n",
    "    id_col=\"item_id\",\n",
    "    timestamp_col=\"timestamp\",\n",
    "):\n",
    "    \"\"\"\n",
    "    Converts past and future DataFrames into JSON payload format for the Chronos endpoint.\n",
    "\n",
    "    Args:\n",
    "        past_df (pd.DataFrame): Historical data with `target_col`, `timestamp_col`, and `id_col`.\n",
    "        future_df (pd.DataFrame, optional): Future covariates with `timestamp_col` and `id_col`.\n",
    "        prediction_length (int): Number of future time steps to predict.\n",
    "        freq (str): Pandas-compatible frequency of the time series.\n",
    "        target_col (str): Column name for target values.\n",
    "        id_col (str): Column name for item IDs.\n",
    "        timestamp_col (str): Column name for timestamps.\n",
    "\n",
    "    Returns:\n",
    "        dict: JSON payload formatted for the Chronos endpoint.\n",
    "    \"\"\"\n",
    "    past_df = past_df.sort_values([id_col, timestamp_col])\n",
    "    if future_df is not None:\n",
    "        future_df = future_df.sort_values([id_col, timestamp_col])\n",
    "\n",
    "    covariate_cols = list(past_df.columns.drop([target_col, id_col, timestamp_col]))\n",
    "    if covariate_cols and (future_df is None or not set(covariate_cols).issubset(future_df.columns)):\n",
    "        raise ValueError(f\"If past_df contains covariates {covariate_cols}, they should also be present in future_df\")\n",
    "\n",
    "    inputs = []\n",
    "    for item_id, past_group in past_df.groupby(id_col):\n",
    "        target_values = past_group[target_col].tolist()\n",
    "\n",
    "        if len(target_values) < 5:\n",
    "            raise ValueError(f\"Time series '{item_id}' has fewer than 5 observations.\")\n",
    "\n",
    "        series_dict = {\n",
    "            \"target\": target_values,\n",
    "            \"item_id\": str(item_id),\n",
    "            \"start\": past_group[timestamp_col].iloc[0].isoformat(),\n",
    "        }\n",
    "\n",
    "        if covariate_cols:\n",
    "            series_dict[\"past_covariates\"] = past_group[covariate_cols].to_dict(orient=\"list\")\n",
    "            future_group = future_df[future_df[id_col] == item_id]\n",
    "            if len(future_group) != prediction_length:\n",
    "                raise ValueError(\n",
    "                    f\"future_df must contain exactly {prediction_length=} values for each item_id from past_df \"\n",
    "                    f\"(got {len(future_group)=}) for {item_id=}\"\n",
    "                )\n",
    "            series_dict[\"future_covariates\"] = future_group[covariate_cols].to_dict(orient=\"list\")\n",
    "\n",
    "        inputs.append(series_dict)\n",
    "\n",
    "\n",
    "    return {\n",
    "        \"inputs\": inputs,\n",
    "        \"parameters\": {\"prediction_length\": prediction_length, \"freq\": freq},\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "payload = convert_df_to_payload(\n",
    "    past_df,\n",
    "    future_df,\n",
    "    prediction_length=prediction_length,\n",
    "    freq=freq,\n",
    "    target_col=\"unit_sales\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now send the payload to the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = predictor.predict(payload)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note how Chronos-Bolt generated predictions for more than 300 time series in the dataset (with covariates!) in less than 2 seconds, even when running on a small CPU instance.\n",
    "\n",
    "Finally, we can convert the response back to a long-format data frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_response_to_df(response, freq=\"D\"):\n",
    "    \"\"\"\n",
    "    Converts a JSON response from the Chronos endpoint into a long-format DataFrame.\n",
    "\n",
    "    Args:\n",
    "        response (dict): JSON response containing forecasts.\n",
    "        freq (str): Pandas-compatible frequency of the time series.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: Long-format DataFrame with timestamps, item_id, and forecasted values.\n",
    "    \"\"\"\n",
    "    dfs = []\n",
    "    for forecast in response[\"predictions\"]:\n",
    "        forecast_df = pd.DataFrame(forecast).drop(columns=[\"start\"])\n",
    "        forecast_df[\"timestamp\"] = pd.date_range(forecast[\"start\"], freq=freq, periods=len(forecast_df))\n",
    "        dfs.append(forecast_df)\n",
    "    return pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean</th>\n",
       "      <th>0.1</th>\n",
       "      <th>0.5</th>\n",
       "      <th>0.9</th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>315.504037</td>\n",
       "      <td>210.074945</td>\n",
       "      <td>315.504037</td>\n",
       "      <td>487.484408</td>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>315.364478</td>\n",
       "      <td>200.272695</td>\n",
       "      <td>315.364478</td>\n",
       "      <td>508.145850</td>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>310.507265</td>\n",
       "      <td>193.902630</td>\n",
       "      <td>310.507265</td>\n",
       "      <td>511.559740</td>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-06-25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>317.322873</td>\n",
       "      <td>200.051215</td>\n",
       "      <td>317.322873</td>\n",
       "      <td>525.013830</td>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>319.089405</td>\n",
       "      <td>199.634549</td>\n",
       "      <td>319.089405</td>\n",
       "      <td>534.102518</td>\n",
       "      <td>1062_101</td>\n",
       "      <td>2018-07-09</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         mean         0.1         0.5         0.9   item_id  timestamp\n",
       "0  315.504037  210.074945  315.504037  487.484408  1062_101 2018-06-11\n",
       "1  315.364478  200.272695  315.364478  508.145850  1062_101 2018-06-18\n",
       "2  310.507265  193.902630  310.507265  511.559740  1062_101 2018-06-25\n",
       "3  317.322873  200.051215  317.322873  525.013830  1062_101 2018-07-02\n",
       "4  319.089405  199.634549  319.089405  534.102518  1062_101 2018-07-09"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forecast_df = convert_response_to_df(response, freq=freq)\n",
    "forecast_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean up the endpoint\n",
    "Don't forget to clean up resources when finished to avoid unnecessary charges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_predictor()"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
